{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "56c899d2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-09T02:46:34.693761Z",
     "start_time": "2024-03-09T02:46:34.676304Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['http_proxy'] = 'http://127.0.0.1:7890'\n",
    "os.environ['https_proxy'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a694d592",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-09T02:46:43.632330Z",
     "start_time": "2024-03-09T02:46:41.679664Z"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision.datasets import MNIST\n",
    "from torchvision.datasets import CIFAR100\n",
    "from torch.utils.data import ConcatDataset\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed144e64",
   "metadata": {},
   "source": [
    "- dataloader 是对 dataset 的进一步封装；"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5803d5d",
   "metadata": {},
   "source": [
    "## dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62f3d74a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T12:48:46.003952Z",
     "start_time": "2024-02-06T12:48:39.631841Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-02-06 20:48:42,066] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "model = SentenceTransformer('all-MiniLM-L6-v2')\n",
    "train_examples = [\n",
    "    InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),\n",
    "    InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8502d372",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T13:01:43.028955Z",
     "start_time": "2024-02-06T13:01:43.020588Z"
    }
   },
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f49a251",
   "metadata": {},
   "source": [
    "### len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "04111c90",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T12:51:30.620722Z",
     "start_time": "2024-02-06T12:51:30.612762Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7f4de3a",
   "metadata": {},
   "source": [
    "### sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5cd5660c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T12:49:06.359825Z",
     "start_time": "2024-02-06T12:49:06.349769Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.sampler.RandomSampler at 0x7fb09001d5a0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataloader.sampler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cb939e2",
   "metadata": {},
   "source": [
    "### collate_fn: 指定如何将一批数据样本组合成一个批次(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "074e0713-028e-449b-8396-4575bd5c0240",
   "metadata": {},
   "source": [
    "- list of `<x, y>` (`dataset.__get_item__`) => batch X tensor, batch y tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "74fbc0b1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T12:49:28.359529Z",
     "start_time": "2024-02-06T12:49:28.350433Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function torch.utils.data._utils.collate.default_collate(batch)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataloader.collate_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93236945",
   "metadata": {},
   "source": [
    "```\n",
    "# from dataset => dataloader\n",
    "class _MapDatasetFetcher(_BaseDatasetFetcher):\n",
    "    def fetch(self, possibly_batched_index):\n",
    "        if self.auto_collation:\n",
    "            if hasattr(self.dataset, \"__getitems__\") and self.dataset.__getitems__:\n",
    "                data = self.dataset.__getitems__(possibly_batched_index)\n",
    "            else:\n",
    "                data = [self.dataset[idx] for idx in possibly_batched_index]\n",
    "        else:\n",
    "            data = self.dataset[possibly_batched_index]\n",
    "        return self.collate_fn(data)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c25f3428",
   "metadata": {},
   "source": [
    "### next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "430d8cc4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T13:00:56.935691Z",
     "start_time": "2024-02-06T13:00:56.930447Z"
    }
   },
   "outputs": [],
   "source": [
    "# next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d17e61a9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-06T13:01:46.538785Z",
     "start_time": "2024-02-06T13:01:46.512161Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([{'input_ids': tensor([[ 101, 2023, 2003, 1037, 3893, 3940,  102],\n",
       "           [ 101, 2023, 2003, 1037, 4997, 3940,  102]]),\n",
       "   'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],\n",
       "           [0, 0, 0, 0, 0, 0, 0]]),\n",
       "   'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],\n",
       "           [1, 1, 1, 1, 1, 1, 1]])},\n",
       "  {'input_ids': tensor([[  101,  2073,  1996,  3292,  2097,  2022, 18478,  2094,   102],\n",
       "           [  101,  2037,  3292,  2097,  2022,  3445,   102,     0,     0]]),\n",
       "   'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "           [0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n",
       "   'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "           [1, 1, 1, 1, 1, 1, 1, 0, 0]])}],\n",
       " tensor([1, 0]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataloader.collate_fn = model.smart_batching_collate\n",
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f73b0bb1",
   "metadata": {},
   "source": [
    "## ConcatDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc315f1b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-09T02:47:48.517372Z",
     "start_time": "2024-03-09T02:47:28.704807Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist:  60000\n",
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 169001437/169001437 [00:14<00:00, 11771577.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/cifar-100-python.tar.gz to ./data\n",
      "cifar:  50000\n",
      "concat_data:  110000\n",
      "(28, 28)\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "mnist_data = MNIST('./data/', train=True, download=True)\n",
    "print('mnist: ', len(mnist_data))\n",
    "cifar10_data = CIFAR100('./data', train=True, download=True)\n",
    "print('cifar: ', len(cifar10_data))\n",
    "concat_data = ConcatDataset([mnist_data, cifar10_data])\n",
    "print('concat_data: ', len(concat_data))\n",
    "img, target = concat_data.__getitem__(133)\n",
    "print(np.array(img).shape)\n",
    "print(target)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0df9900d-122c-4217-a0f7-1d483f8c86b1",
   "metadata": {},
   "source": [
    "## custom dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8f8fece6-8b12-4c48-97d7-fd3e5aa08e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.datasets import make_classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "01eca895-fe87-46ab-8707-596b04fa5711",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000 1000\n",
      "[-0.50732797 -0.0859023   0.79519987 -0.941882    0.60729342  0.37327327\n",
      " -1.35511381  0.08934151  0.38568075 -0.82372423  0.21789479  1.14796323\n",
      "  0.38797855  0.23849993 -1.66507864  0.39428038 -2.59608648 -0.97139603\n",
      " -0.32160851  0.16779007] 0\n"
     ]
    }
   ],
   "source": [
    "data, targets = make_classification(n_samples=1000)\n",
    "print(len(data), len(targets))\n",
    "print(data[0], targets[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2b02345b-d45e-4a11-bdde-90cab0f7441a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32 torch.int64\n"
     ]
    }
   ],
   "source": [
    "print(torch.float, torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7bc57990-a07d-48d7-83ab-d1731eb51be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, data, targets):\n",
    "        self.data = data\n",
    "        self.targets = targets\n",
    "    def __len__(self):\n",
    "        return self.data.shape[0]\n",
    "    def __getitem__(self, idx):\n",
    "        return {\n",
    "            'x': torch.tensor(self.data[idx, :], dtype=torch.float),\n",
    "            'y': torch.tensor(self.targets[idx], dtype=torch.long)\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "02d4f5aa-94a9-470d-8c55-33eb30ba7410",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([-0.5073, -0.0859,  0.7952, -0.9419,  0.6073,  0.3733, -1.3551,  0.0893,\n",
       "          0.3857, -0.8237,  0.2179,  1.1480,  0.3880,  0.2385, -1.6651,  0.3943,\n",
       "         -2.5961, -0.9714, -0.3216,  0.1678]),\n",
       " 'y': tensor(0)}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = CustomDataset(data, targets)\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d447e5d3-d52b-49d0-aa6c-578e98640ca9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "30538425-59c6-4cc3-8383-7da0d6741059",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b367568b-ce44-4e57-a2c0-80ef9ec0a684",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 20])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(train_dataloader))['x'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "503d9dfa-973e-4ccb-a37d-99764c4c4a7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 20]) torch.Size([32])\n"
     ]
    }
   ],
   "source": [
    "for batch in train_dataloader:\n",
    "    batch_x = batch['x']\n",
    "    batch_y = batch['y']\n",
    "    print(batch_x.shape, batch_y.shape)\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
